# Copyright Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from unittest.mock import MagicMock
import unittest
import pytest

from testing_utils.mock import mock_kia_library
mock_kia_library()

from msmodelslim.pytorch.llm_ptq.model.hunyuan.hunyuan_video import HunyuanVideoAdapter
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig


class TestGetNormLinearSubgraph(unittest.TestCase):
    """测试 get_norm_linear_subgraph 函数的不同场景"""

    @staticmethod
    def create_adapter(num_layers):
        """创建模拟的预训练模型"""
        model = MagicMock()
        model.config = MagicMock()
        model.config.mm_double_blocks_depth = num_layers
        adapter = HunyuanVideoAdapter(model)
        return adapter

    def test_min_layer(self):
        """测试边界条件：模型最少层数"""
        adapter = self.create_adapter(1)
        result = adapter.get_norm_linear_subgraph(AntiOutlierConfig(anti_method='m4'))

        self.assertEqual(len(result), 2)
        key0_img = '0img_qkv_anti'
        self.assertIn(key0_img, result)
        self.assertEqual(len(result[key0_img]), 1)
        self.assertIn('double_blocks.0.img_attn_qkv', result[key0_img])

        key0_txt = '0txt_qkv_anti'
        self.assertIn(key0_txt, result)
        self.assertEqual(len(result[key0_txt]), 1)
        self.assertIn('double_blocks.0.txt_attn_qkv', result[key0_txt])

    def test_mm_double_blocks_depth_not_integer(self):
        """测试 mm_double_blocks_depth 不是整数"""
        adapter = self.create_adapter("invalid")
        with self.assertRaises(TypeError) as context:
            adapter.get_norm_linear_subgraph(AntiOutlierConfig(anti_method='m4'))
        self.assertEqual(str(context.exception), "mm_double_blocks_depth must be an integer.")

    def test_invalid_layers_zero(self):
        """测试边界条件：模型层数小于条件值"""
        adapter = self.create_adapter(0)
        with self.assertRaises(ValueError) as context:
            adapter.get_norm_linear_subgraph(AntiOutlierConfig(anti_method='m4'))
        self.assertEqual(str(context.exception), "mm_double_blocks_depth must be in the range 1 to 999.")

    def test_invalid_layers_exceed(self):
        """测试边界条件：模型层数大于条件值"""
        adapter = self.create_adapter(1000)
        with self.assertRaises(ValueError) as context:
            adapter.get_norm_linear_subgraph(AntiOutlierConfig(anti_method='m4'))
        self.assertEqual(str(context.exception), "mm_double_blocks_depth must be in the range 1 to 999.")


class TestModifySmoothArgs(unittest.TestCase):
    """测试 modify_smooth_args 函数的不同场景"""

    @staticmethod
    def create_adapter(num_layers):
        """创建模拟的预训练模型"""
        model = MagicMock()
        model.config = MagicMock()
        model.config.mm_double_blocks_depth = num_layers
        adapter = HunyuanVideoAdapter(model)
        return adapter

    @staticmethod
    def create_anti_config(method='m4'):
        cfg = AntiOutlierConfig(anti_method=method)
        return cfg

    def test_m4_method(self):
        """测试 anti_method 为 'm4' 时的逻辑"""
        adapter = self.create_adapter(3)
        cfg = self.create_anti_config()
        norm_name = "layer1"
        linear_names = "layer1_linear"
        args = [1, 2, 3]
        kwargs = {'param1': 'value1'}

        updated_args, updated_kwargs = adapter.modify_smooth_args(cfg, norm_name, linear_names, args, kwargs)

        # 验证 args 未被修改
        self.assertEqual(updated_args, args)
        # 验证 kwargs 是否正确更新
        self.assertEqual(updated_kwargs['is_shift'], False)
        self.assertEqual(updated_kwargs['alpha'], cfg.alpha)
        self.assertIn('param1', updated_kwargs)
        self.assertEqual(updated_kwargs['param1'], 'value1')

        assert kwargs["is_shift"] is False
        assert kwargs["alpha"] == cfg.alpha

    def test_non_m4_method(self):
        """测试 anti_method 不为 'm4' 时的逻辑"""
        adapter = self.create_adapter(3)
        cfg = self.create_anti_config(method='m1')
        norm_name = "layer1"
        linear_names = "layer1_linear"
        args = [1, 2, 3]
        kwargs = {'param1': 'value1'}

        updated_args, updated_kwargs = adapter.modify_smooth_args(cfg, norm_name, linear_names, args, kwargs)

        # 验证 args 未被修改
        self.assertEqual(updated_args, args)
        # 验证 kwargs 未被修改
        self.assertEqual(updated_kwargs, kwargs)

    def test_override_existing_kwargs(self):
        """测试覆盖已存在的 kwargs 参数"""
        adapter = self.create_adapter(3)
        cfg = self.create_anti_config()
        norm_name = "norm.bias"
        linear_names = "dummy"
        args = []
        kwargs = {"is_shift": True, "alpha": 1.0}

        args, kwargs = adapter.modify_smooth_args(cfg, norm_name, linear_names, args, kwargs)

        assert kwargs["is_shift"] is False  # 应覆盖原有值
        assert kwargs["alpha"] == cfg.alpha  # 应覆盖原有值
